{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# RNN从零开始编程",
   "id": "16618b1a7786d3f8"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 1. 下载数据",
   "id": "7cbced3ea6baf9f9"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T06:33:56.981961Z",
     "start_time": "2024-08-28T06:33:55.279855Z"
    }
   },
   "cell_type": "code",
   "source": [
    "%matplotlib inline\n",
    "import math\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.nn import functional as F\n",
    "from d2l import torch as d2l\n",
    "\n",
    "batch_size, num_steps = 32, 35\n",
    "train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)"
   ],
   "id": "d2a7603a840708bd",
   "execution_count": 2,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T06:36:52.318623Z",
     "start_time": "2024-08-28T06:36:52.314024Z"
    }
   },
   "cell_type": "code",
   "source": [
    "# 字符对应的标号\n",
    "vocab.token_to_idx"
   ],
   "id": "9e2ec916d498e9e4",
   "execution_count": 15,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 2. 初始化模型参数",
   "id": "47e9df394d63e2d0"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T06:57:57.374487Z",
     "start_time": "2024-08-28T06:57:57.369091Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def get_params(vocab_size, num_hiddens, device):\n",
    "    num_inputs = num_outputs = vocab_size\n",
    "\n",
    "    def normal(shape):\n",
    "        return torch.randn(size=shape, device=device) * 0.01\n",
    "\n",
    "    # 隐藏层参数\n",
    "    W_xh = normal((num_inputs, num_hiddens))\n",
    "    W_hh = normal((num_hiddens, num_hiddens))\n",
    "    b_h = torch.zeros(num_hiddens, device=device)\n",
    "    # 输出层参数\n",
    "    W_hq = normal((num_hiddens, num_outputs))\n",
    "    b_q = torch.zeros(num_outputs, device=device)\n",
    "    # 附加梯度\n",
    "    params = [W_xh, W_hh, b_h, W_hq, b_q]\n",
    "    for param in params:\n",
    "        param.requires_grad_(True)\n",
    "    return params"
   ],
   "id": "733bc07e4e0f388c",
   "execution_count": 23,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 3. 循环神经网络模型",
   "id": "511f4a4769908fb4"
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "为了定义循环神经网络模型， 我们首先需要一个`init_rnn_state`函数在初始化时返回隐状态。 这个函数的返回是一个张量，张量全用0填充， 形状为`（批量大小，隐藏单元数）`。 在后面的章节中我们将会遇到隐状态包含多个变量的情况， 而使用元组可以更容易地处理些。",
   "id": "74632cc08c75086f"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T07:04:57.553481Z",
     "start_time": "2024-08-28T07:04:57.543004Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def init_rnn_state(batch_size, num_hiddens, device):\n",
    "    return (torch.zeros((batch_size, num_hiddens), device=device),)"
   ],
   "id": "dbf7e60c2c1d7baa",
   "execution_count": 24,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "下面的rnn函数定义了如何在一个时间步内计算隐状态和输出。 循环神经网络模型通过inputs最外层的维度实现循环， 以便逐时间步更新小批量数据的隐状态H。 此外，这里使用函数作为激活函数。 如 4.1节所述， 当元素在实数上满足均匀分布时，函数的平均值为0。",
   "id": "653a647302d497c4"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T07:42:36.433266Z",
     "start_time": "2024-08-28T07:42:36.421954Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def rnn(inputs, state, params):\n",
    "    '''\n",
    "    RNN cell\n",
    "    :param inputs: 批量输入特征 形状：(时间步数量，批量大小，词表大小)\n",
    "    :param state: 初始状态\n",
    "    :param params: 可学习参数\n",
    "    :return: 当前的输出和当前的隐藏状态\n",
    "    '''\n",
    "    # inputs的形状：(时间步数量，批量大小，词表大小)\n",
    "    W_xh, W_hh, b_h, W_hq, b_q = params\n",
    "    H, = state\n",
    "    outputs = []\n",
    "    # X的形状：(批量大小，词表大小)\n",
    "    for X in inputs:\n",
    "        # 时刻0的样本和样本字符列表，时刻1的样本和样本字符列表，...\n",
    "        H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)\n",
    "        Y = torch.mm(H, W_hq) + b_q  # 当前时刻的输入和隐变量预测出的下一时刻的输出\n",
    "        outputs.append(Y)\n",
    "    return torch.cat(outputs, dim=0), (H,)"
   ],
   "id": "d174adf04a9b7308",
   "execution_count": 25,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "定义了所有需要的函数之后，接下来我们创建一个类来包装这些函数， 并存储从零开始实现的循环神经网络模型的参数",
   "id": "588baac61c6aab45"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T07:49:40.045316Z",
     "start_time": "2024-08-28T07:49:40.031463Z"
    }
   },
   "cell_type": "code",
   "source": [
    "class RNNModelScratch:  #@save\n",
    "    \"\"\"从零开始实现的循环神经网络模型\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size, num_hiddens, device, get_params, init_state, forward_fn):\n",
    "        self.vocab_size, self.num_hiddens = vocab_size, num_hiddens\n",
    "        self.params = get_params(vocab_size, num_hiddens, device)\n",
    "        self.init_state, self.forward_fn = init_state, forward_fn\n",
    "        # forward_fn前向传播函数，这里可以传刚才定义的rnn函数\n",
    "\n",
    "    def __call__(self, X, state):\n",
    "        X = F.one_hot(X.T, self.vocab_size).type(torch.float32)\n",
    "        return self.forward_fn(X, state, self.params)\n",
    "\n",
    "    def begin_state(self, batch_size, device):\n",
    "        return self.init_state(batch_size, self.num_hiddens, device)"
   ],
   "id": "a964f00c8919a0c8",
   "execution_count": 26,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "检查输出是否具有正确的形状",
   "id": "f9b54409e0bae919"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T07:53:48.107534Z",
     "start_time": "2024-08-28T07:53:48.042895Z"
    }
   },
   "cell_type": "code",
   "source": [
    "num_hiddens = 512\n",
    "net = RNNModelScratch(\n",
    "    len(vocab),\n",
    "    num_hiddens,\n",
    "    d2l.try_gpu(),\n",
    "    get_params,\n",
    "    init_rnn_state,\n",
    "    rnn\n",
    ")\n",
    "\n",
    "# 生成检查数据\n",
    "X = torch.arange(10).reshape((2, 5))\n",
    "# 初始化第0个隐藏状态\n",
    "state = net.begin_state(X.shape[0], d2l.try_gpu())\n",
    "\n",
    "Y, new_state = net(X.to(d2l.try_gpu()), state)\n",
    "Y.shape, len(new_state), new_state[0].shape"
   ],
   "id": "c754b429cf85d459",
   "execution_count": 28,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 4. 预测",
   "id": "82823e2fca66de22"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:13:42.085508Z",
     "start_time": "2024-08-28T08:13:42.066974Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def predict_ch8(prefix, num_preds, net, vocab, device):  #@save\n",
    "    \"\"\"\n",
    "    \n",
    "    :param prefix: prefix是一个用户提供的包含多个字符的字符串，该函数根据这个prefix生成之后的字符串 \n",
    "    :param num_preds: 需要生成的字符/字符串长度\n",
    "    :param net: 网络模型\n",
    "    :param vocab: 词汇表，根据vacab实现one-hot编码向字符的转变\n",
    "    :param device: cpu或gpu\n",
    "    :return: 生成的字符串\n",
    "    \"\"\"\n",
    "    \"\"\"在prefix后面生成新字符\"\"\"\n",
    "\n",
    "    state = net.begin_state(batch_size=1, device=device)  # batch_size=1 是因为只要生成一个字符串\n",
    "    outputs = [vocab[prefix[0]]]  # 只是把用户输入的字符对应的标号放在开头而已，并无特殊含义 e.g. [4, ]\n",
    "\n",
    "    # 定义函数get_input()，每次拿到outputs的最后一个词，转换成tensor\n",
    "    get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n",
    "    # def get_input():\n",
    "    #     return torch.tensor([outputs[-1]], device=device).reshape((1, 1))\n",
    "\n",
    "    for y in prefix[1:]:  # 预热期，利用输入的prefix的信息不断完善隐状态state，并把完整的prefix存入outputs，并不在意预测值\n",
    "        _, state = net(get_input(), state)\n",
    "        outputs.append(vocab[y])  # e.g. [4, 3, 2, ...]\n",
    "\n",
    "    # 预测num_preds步\n",
    "    for _ in range(num_preds):\n",
    "        y, state = net(get_input(), state)\n",
    "        outputs.append(int(y.argmax(dim=1).reshape(1)))  # 分类问题，找到len(vocab)个元素中最大的的下标，加入outputs\n",
    "\n",
    "    return ''.join([vocab.idx_to_token[i] for i in outputs])  # vocab.idx_to_token[i]在词汇表中把标号转为字符，然后输出"
   ],
   "id": "c224e8514df0d46",
   "execution_count": 29,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "现在我们可以测试predict_ch8函数。 我们将前缀指定为time traveller， 并基于这个前缀生成10个后续字符。 鉴于我们还没有训练网络，它会生成荒谬的预测结果。",
   "id": "8c54f40adddc09ac"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:15:36.874842Z",
     "start_time": "2024-08-28T08:15:36.864150Z"
    }
   },
   "cell_type": "code",
   "source": [
    "predict_ch8(\n",
    "    prefix='time traveller ',\n",
    "    num_preds=10,\n",
    "    net=net,\n",
    "    vocab=vocab,\n",
    "    device=d2l.try_gpu()\n",
    ")"
   ],
   "id": "4c30aff234108ecf",
   "execution_count": 31,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 辅助函数：梯度裁剪",
   "id": "7a41290e75b2ef43"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:25:55.337591Z",
     "start_time": "2024-08-28T08:25:55.331228Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def grad_clipping(net, theta):  #@save\n",
    "    \"\"\"裁剪梯度\"\"\"\n",
    "\n",
    "    # 拿出所有需要学习的参数，即需要requires_grad==True\n",
    "    if isinstance(net, nn.Module):  # 如果采用简洁写法nn.Module\n",
    "        params = [p for p in net.parameters() if p.requires_grad]\n",
    "    else:  # 如果采用从零开始实现的方法\n",
    "        params = net.params\n",
    "\n",
    "    # 先把所有的学习参数拼成一个向量，然后直接求L2范数\n",
    "    norm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))\n",
    "\n",
    "    if norm > theta:\n",
    "        # 梯度太大，就把所有梯度同时缩放\n",
    "        for param in params:\n",
    "            param.grad[:] *= theta / norm"
   ],
   "id": "9dc5f51273b64d57",
   "execution_count": 33,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## 5. 开始训练",
   "id": "e301401c1647b2b7"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:45:06.724406Z",
     "start_time": "2024-08-28T08:45:06.714348Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):\n",
    "    '''\n",
    "    \n",
    "    :param net: 模型\n",
    "    :param train_iter: 数据 \n",
    "    :param loss: 损失\n",
    "    :param updater: 优化器\n",
    "    :param device: cpu或gpu\n",
    "    :param use_random_iter: True随机抽样，False顺序分区 \n",
    "    :return: 困惑率 和 速度\n",
    "    '''\n",
    "    \"\"\"训练网络一个迭代周期（定义见第8章）\"\"\"\n",
    "    state, timer = None, d2l.Timer()\n",
    "    metric = d2l.Accumulator(2)  # 训练损失之和,词元数量\n",
    "\n",
    "    # 计算图 x -> s -> y 如果s.detach_() 则 x, s -> y 且设置s.requires_grad=False 即y.backward时不会求s的梯度\n",
    "    for X, Y in train_iter:\n",
    "        if state is None or use_random_iter:\n",
    "            # 在第一次迭代初始化state 或 使用随机抽样时初始化state，因为对于随机抽样上一个state与这一次相互独立\n",
    "            state = net.begin_state(batch_size=X.shape[0], device=device)\n",
    "        else:\n",
    "            if isinstance(net, nn.Module) and not isinstance(state, tuple):\n",
    "                # state对于nn.GRU是个张量\n",
    "                state.detach_()\n",
    "            else:\n",
    "                # state对于nn.LSTM或对于我们从零开始实现的模型是个张量\n",
    "                for s in state:\n",
    "                    s.detach_()\n",
    "\n",
    "        y = Y.T.reshape(-1)\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        y_hat, state = net(X, state)\n",
    "        # y_hat和y都是batch_size * nun_steps长的tensor\n",
    "        l = loss(y_hat, y.long()).mean()\n",
    "\n",
    "        if isinstance(updater, torch.optim.Optimizer):\n",
    "            updater.zero_grad()\n",
    "            l.backward()\n",
    "            grad_clipping(net, 1)  # 裁剪梯度 theta=1\n",
    "            updater.step()\n",
    "        else:\n",
    "            l.backward()\n",
    "            grad_clipping(net, 1)\n",
    "            # 因为已经调用了mean函数\n",
    "            updater(batch_size=1)\n",
    "\n",
    "        # metric[0]存储总损失，metric[1]存储样本数量\n",
    "        metric.add(l * y.numel(), y.numel())\n",
    "    return math.exp(metric[0] / metric[1]), metric[1] / timer.stop()"
   ],
   "id": "129264d105cea3fa",
   "execution_count": 35,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "循环神经网络模型的训练函数既支持从零开始实现， 也可以使用高级API来实现",
   "id": "da993e8a9d44e43d"
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:52:49.918387Z",
     "start_time": "2024-08-28T08:52:49.904287Z"
    }
   },
   "cell_type": "code",
   "source": [
    "def train_ch8(net, train_iter, vocab, lr, num_epochs, device, use_random_iter=False):\n",
    "    \"\"\"训练模型（定义见第8章）\"\"\"\n",
    "    loss = nn.CrossEntropyLoss()  # 实际上就是分类模型\n",
    "\n",
    "    # 画图\n",
    "    animator = d2l.Animator(xlabel='epoch', ylabel='perplexity', legend=['train'], xlim=[10, num_epochs])\n",
    "\n",
    "    # 初始化\n",
    "    # 优化器\n",
    "    if isinstance(net, nn.Module):\n",
    "        updater = torch.optim.SGD(net.parameters(), lr)\n",
    "    else:\n",
    "        updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)\n",
    "\n",
    "    predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, device)\n",
    "\n",
    "    # 训练和预测\n",
    "    for epoch in range(num_epochs):\n",
    "        ppl, speed = train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)\n",
    "        if (epoch + 1) % 10 == 0:\n",
    "            # 以'time traveller'为开头的训练\n",
    "            print(predict('time traveller'))\n",
    "            animator.add(epoch + 1, [ppl])\n",
    "\n",
    "    print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')\n",
    "\n",
    "    print(predict('time traveller'))  # 预测'time traveller'之后的生成\n",
    "    print(predict('traveller'))  # 预测'traveller'之后的生成"
   ],
   "id": "cfda9b3f4ba48e68",
   "execution_count": 37,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "num_epochs, lr = 500, 1\n",
    "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu())"
   ],
   "id": "7ba47c221c534724",
   "execution_count": 38,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:55:44.868351Z",
     "start_time": "2024-08-28T08:55:44.842071Z"
    }
   },
   "cell_type": "code",
   "source": [
    "predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, d2l.try_gpu())\n",
    "print(predict('time traveller'))"
   ],
   "id": "675ef75d72fb783e",
   "execution_count": 40,
   "outputs": []
  },
  {
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-08-28T08:57:06.817989Z",
     "start_time": "2024-08-28T08:57:06.812554Z"
    }
   },
   "cell_type": "code",
   "source": [
    "predict = lambda prefix: predict_ch8(prefix, 50, net, vocab, d2l.try_gpu())\n",
    "print(predict('traveller'))"
   ],
   "id": "39fc1667786e046e",
   "execution_count": 44,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "随机抽样方法的结果",
   "id": "50264d977e62cbdc"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params, init_rnn_state, rnn)\n",
    "train_ch8(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(), use_random_iter=True)"
   ],
   "id": "c95c730e16968a9c",
   "execution_count": 45,
   "outputs": []
  },
  {
   "metadata": {},
   "cell_type": "code",
   "execution_count": null,
   "source": "",
   "id": "12b3fdd172b9d106",
   "outputs": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
